% This m-file fits single-emitter spots with a vectorial PSF model.
% It supports standard 2D fitting (xy), and extends this to 3D fitting
% (xyz), to xy-wavelength fitting (xylambda), to 4D fitting (xyzlambda),
% and to aberration retrieval from through-focus stacks.
%
% Sjoerd Stallinga, TU Delft

% (C) Copyright 2018
% All rights reserved
% Department of Imaging Physics
% Faculty of Applied Sciences
% Delft University of Technology
% Delft, The Netherlands   

% start from a clean session: close all figures and clear the workspace
close all
clear all

%%
% run switches and PSF model / fit parameters

% show figures (1) or not (0)
fig_flag = 1;
% presumably controls storing of results - not used further in this script
storeresults = 1;

% model settings, presumably configured for aberration retrieval
parameters = set_parameters_aberrations;

% directory with the data reading / spot localization routines
addpath('./PSFlocalization')

% number of the data set to analyze, and number of spots
jj = 8;
Nspot = 1;

% name of the experimental data set; the data number is zero-padded.
% No trailing semicolon, so the chosen name is echoed to the console.
if jj < 10
    data_str = strcat('45nm_488_experimentpsfcontrol_00', num2str(jj), 'z')
else
    data_str = strcat('45nm_488_experimentpsfcontrol_0', num2str(jj), 'z')
end

% camera calibration: gain (ADU conversion factor) and offset (mean
% background level); presumably applied to the raw frames inside
% get_allspots - TODO confirm
gain = 0.68;
offset = 101.45;

% cut out the spot(s); the x,y-coordinates of the manually localized spot
% are set in PSFlocalization\get_allspots.m
allspots = get_allspots(parameters,data_str,offset,gain);

% short pause, presumably to let pending figure output render
pause(0.1)

%%
% maximum-likelihood (MLE) fit

% sampling grids in the pupil plane and in the image plane
[XPupil,YPupil,XImage,YImage] = get_coords(parameters);

% starting point of the fit: initial estimates of the fit parameters
thetainit = initialvalues(allspots,XImage,YImage,parameters);

% fit settings
parameters.Nitermax = 75;                      % max # of iterations
parameters.tollim = 1e-6;                      % tolerance limit of stopping criterion
parameters.varfit = 0;                         % assumed read noise variance
parameters.optmethod = 'levenbergmarquardt';   % optimizer type
% parameters.optmethod = 'newtonraphson';      % alternative optimizer

% run the fit; the estimates after the last iteration are the final result
% (indexing with three subscripts already yields a 2-D array)
[thetastore,meritstore,Hessianstore,numiters] = localization(allspots,thetainit,parameters);
thetafinal = thetastore(:,:,end);

% Fisher matrix, CRLB and chi-square for each fitted configuration.
% plotrange indicates which fit results are plotted
% plotrange = 1:Ncfg;
Fisherstore = zeros(parameters.numparams,parameters.numparams,parameters.Ncfg);
CRLBstore = zeros(parameters.numparams,parameters.Ncfg);
chisquarestore = zeros(1,parameters.Ncfg);

for jcfg = 1:parameters.Ncfg
    % fitted parameter vector of this configuration: x and y come first,
    % photon count and background per pixel come last
    theta = thetafinal(:,jcfg);
    Nph = theta(parameters.numparams-1);
    bg = theta(parameters.numparams);

    % write the fitted values into the model parameters used by the PSF
    % routines below
    parameters.xemit = theta(1);
    parameters.yemit = theta(2);
    if strcmp(parameters.fitmodel,'xylambda')
        parameters.lambda = theta(3);
    elseif any(strcmp(parameters.fitmodel,{'xyz','xyzlambda','aberrations'}))
        parameters.zemit = theta(3);
        if strcmp(parameters.fitmodel,'xyzlambda')
            parameters.lambda = theta(4);
        elseif strcmp(parameters.fitmodel,'aberrations')
            % Zernike coefficients sit between z and the photon count
            parameters.aberrations(:,3) = theta(4:parameters.numparams-2);
        end
    end
    parameters.signalphotoncount = Nph;
    parameters.backgroundphotoncount = bg;

    % PSF model and its derivatives evaluated at the fitted parameters
    [~,~,wavevector,wavevectorzimm,Waberration,allzernikes,PupilMatrix] = get_pupil_matrix(parameters);
    [~,~,FieldMatrix,FieldMatrixDerivatives] = get_field_matrix_derivatives( ...
        PupilMatrix,wavevector,wavevectorzimm,Waberration,allzernikes,parameters);
    [allPSFs,PSFderivatives] = get_psfs_derivatives(FieldMatrix,FieldMatrixDerivatives,parameters);

    % precision limit and goodness of fit
    [Fisher,CRLB] = get_fisher_crlb(allPSFs,PSFderivatives,parameters);
    imagecfg = squeeze(allspots(:,:,:,jcfg));
    mucfg = Nph*allPSFs+bg;
    % NOTE(review): mucfg (scaled PSF plus background) is computed but not
    % used; get_chisquare receives allPSFs - confirm which input it expects
    [chisquare,chisquaremean,chisquarestd] = get_chisquare(allPSFs,imagecfg,parameters);

    Fisherstore(:,:,jcfg) = Fisher;
    CRLBstore(:,jcfg) = CRLB;
    chisquarestore(jcfg) = chisquare;
end

% Average the fit results and CRLBs over the fitted configurations
% ("random instances"). Only meaningful when there is more than one.
%
% std(...,1,2) normalizes by N and works along the configuration dimension.
% Subtracting a scalar ground-truth value would not change the spread, and
% experimental data has no ground truth. The photon count and background
% therefore need no special treatment. The previous per-index correction
% referred to undefined variables (Nphindex, Nphtrue, bgindex, bgtrue) and
% crashed whenever Ncfg > 1.
if parameters.Ncfg>1
    thetamean = mean(thetafinal,2);
    thetastd = std(thetafinal,1,2);
    CRLBmean = mean(CRLBstore,2);
    CRLBstd = std(CRLBstore,1,2);
end

%%
% plot results
%
% The fitted parameter vector is ordered as x, y, then z and/or lambda
% (depending on the fit model), then Zernike coefficients ('aberrations'
% only), then photon count and background per pixel as the last two entries.
if fig_flag
    % axis labels per fit model (yerrorlabels/ystdlabels are not used by the
    % plots below). Photon count and background are counts, not lengths, so
    % their labels carry no (nm) unit.
    switch parameters.fitmodel
        case 'xy'
            ylabels = {'X (nm)','Y (nm)','Photon count','Background/pixel'};
            yerrorlabels = {'X error (nm)','Y error (nm)','Photon count error','Background/pixel error'};
            ystdlabels = {'X standard deviation (nm)','Y standard deviation (nm)','N_{ph} standard deviation','bg standard deviation'};
        case 'xyz'
            ylabels = {'X (nm)','Y (nm)','Z (nm)','Photon count','Background/pixel'};
            yerrorlabels = {'X error (nm)','Y error (nm)','Z error (nm)','Photon count error','Background/pixel error'};
            ystdlabels = {'X standard deviation (nm)','Y standard deviation (nm)','Z standard deviation (nm)','N_{ph} standard deviation','bg standard deviation'};
        case 'xylambda'
            ylabels = {'X (nm)','Y (nm)','Wavelength (nm)','Photon count','Background/pixel'};
            yerrorlabels = {'X error (nm)','Y error (nm)','Wavelength error (nm)','Photon count error','Background/pixel error'};
            ystdlabels = {'X standard deviation (nm)','Y standard deviation (nm)','\lambda standard deviation (nm)','N_{ph} standard deviation','bg standard deviation'};
        case 'xyzlambda'
            ylabels = {'X (nm)','Y (nm)','Z (nm)','\lambda (nm)','Photon count','Background/pixel'};
            yerrorlabels = {'X error (nm)','Y error (nm)','Z error (nm)','\lambda error (nm)','Photon count error','Background/pixel error'};
            ystdlabels = {'X standard deviation (nm)','Y standard deviation (nm)','Z standard deviation (nm)','\lambda standard deviation (nm)','N_{ph} standard deviation','bg standard deviation'};
        case 'aberrations'
            ylabels = {'X (nm)','Y (nm)','Z (nm)','Photon count','Background/pixel'};
            yerrorlabels = {'X error (nm)','Y error (nm)','Z error (nm)','Photon count error','Background/pixel error'};
            ystdlabels = {'X standard deviation (nm)','Y standard deviation (nm)','Z standard deviation (nm)','N_{ph} standard deviation','bg standard deviation'};
        otherwise
            % without labels the convergence plots below would fail with an
            % obscure "undefined variable" error
            error('Unknown fit model ''%s''',parameters.fitmodel);
    end
    
    % monitor convergence of iterative procedure
    showconvergence = 1;
    if showconvergence
        
        %% plot iterations and merit function
        scrsz = get(0,'ScreenSize');
        figure('Position',[3*scrsz(4)/8 3*scrsz(4)/8 2*scrsz(3)/4 3*scrsz(4)/8]);
        subplot(1,2,1)
        box on
        plot(numiters,'-or')
        xlabel('random instance')
        ylabel('# iterations to convergence')
        subplot(1,2,2)
        hold on
        box on
        plot(meritstore')
        xlabel('# iterations')
        ylabel('merit function')
        
        %% plot (x,y,z), photon count, background/pixel convergence
        figure('Position',[1*scrsz(4)/8 1*scrsz(4)/8 3*scrsz(3)/4 3*scrsz(4)/4]);
        if strcmp(parameters.fitmodel,'aberrations')
            % skip the Zernike coefficients here, they get their own figure
            jall = [1,2,3,parameters.numparams-1,parameters.numparams];
        else
            jall = 1:parameters.numparams;
        end
        for j=1:length(jall)
            subplot(2,3,j)
            hold on;
            box on;
            plot(squeeze(thetastore(jall(j),:,:))');
            xlabel('# iterations')
            ylabel(ylabels{j});
        end
        
        %% plot Zernike modes convergence
        if strcmp(parameters.fitmodel,'aberrations')
            figure('Position',[1*scrsz(4)/8 1*scrsz(4)/8 3*scrsz(3)/4 3*scrsz(4)/4]);
            orders = parameters.aberrations(:,1:2);
            % all parameters except x, y, z, photon count and background
            numzers = parameters.numparams-5;
            subplotsize = ceil(sqrt(numzers));
            for j=1:numzers
                jzer = 3+j;
                subplot(subplotsize,subplotsize,j)
                hold on;
                box on;
                plot(squeeze(thetastore(jzer,:,:))');
                %       plot(squeeze(thetastore(jzer,~outliers,:))');
                %       plot(squeeze(thetastore(jzer,outliers,:))','k','LineWidth',2);
                xlabel('# iterations')
                ylabel(strcat('A_{',num2str(orders(j,1)),',',num2str(orders(j,2)),'}'));
            end
        end
        
    end
    
    if strcmp(parameters.fitmodel,'aberrations')
        %% plot photon count in ROI (data/fit)
        numphotons_data = squeeze(sum(sum(allspots)));
        % NOTE(review): allPSFs and the photon counts in parameters come from
        % the last configuration handled in the CRLB loop; this assumes a
        % single configuration (Ncfg = 1) for aberration retrieval
        mu = parameters.signalphotoncount*allPSFs+parameters.backgroundphotoncount;
        numphotons_fit = squeeze(sum(sum(mu)));
        zall = linspace(parameters.zrange(1),parameters.zrange(2),parameters.Mz);
        % fixed reference screen size so the figure layout is reproducible
        scrsz = [1 1 1366 768];
        figure
        set(gcf,'Position',[0.2*scrsz(4) 0.25*scrsz(4) 0.7*scrsz(4) 0.4*scrsz(4)]);
        box on
        hold on
        plot(zall,numphotons_data,'or','MarkerSize',8)
        plot(zall,numphotons_fit,'x-r','MarkerSize',3,'LineWidth',2)
        xlim([-1100 1100])
        ylim([0 12e4])
        xlabel('z_{stage} (nm)')
        ylabel('photon count in ROI');
        set(gca,'FontSize',12)
        legend('data','fit','Location','South')
        
        %% plot rms Zernike modes (mlambda)
        scrsz = [1 1 1366 768];
        figure
        set(gcf,'Position',[0.2*scrsz(4) 0.25*scrsz(4) 2*scrsz(4) 0.65*scrsz(4)]);
        box on
        hold on
        orders = parameters.aberrations(:,1:2);
        numzers = parameters.numparams-5;
        allxticks = 1:numzers;
        allxticklabels = cell(numzers,1);
        for jzer = 1:numzers
            allxticklabels{jzer} = strcat(num2str(orders(jzer,1)),',',num2str(orders(jzer,2)));
        end
        jcfg = 1;
        % total rms wavefront aberration: root-sum-square of the coefficients
        Wrms = sqrt(sum(thetafinal(4:3+numzers,jcfg).^2));
        plot(1:numzers,thetafinal(4:3+numzers,jcfg),'o-r','MarkerSize',8,'LineWidth',2)
        plot(0:numzers+1,0*(0:numzers+1),'-k')
        xticks(allxticks)
        %     yticks([-50 -40 -30 -20 -10 0 10 20 30 40 50])
        xticklabels(allxticklabels)
        xtickangle(45)
        ylabel('rms Zernike mode (m{\lambda})');
        % bracket concatenation keeps the spaces; strcat would strip the
        % trailing blank of 'W_{rms} = '
        title(['W_{rms} = ' num2str(Wrms,3) ' m{\lambda}'])
        %     ylim([-50 50])
        xlim([0 numzers+1])
        set(gca,'FontSize',12)
    end
    
end


